# ============================================================
# Thomas König & Stefan Eschenwecker
# The European Court of Justice and legal European integration
# 
# This script produces Figure 4 in the main text.
# ============================================================

# load packages
library(rstan)
library(brms)
library(tidyverse)

# read the prepared court data and keep only rows where all_not_applic is 0
CourtData <- filter(
  read.csv("Court/Data/PreparedCourtData.csv"),
  all_not_applic == 0
)

# read the saved regression model used for all predictions below
RegModel <- readRDS("Court/Models/UnequalPrecision.rds")

# Predicted outcomes incorporating case-specific uncertainty + stochastic uncertainty

# labels of the three outcome categories
OutcomeLabels <- c("Preserve Sovereignty", "Ambivalent", "More Europe")

# scenario dimensions: supranational signal, shift of the net member-state
# preference, and shift of the member-state signal intensity
# ("obs" = keep observed values, otherwise a shift in standard deviations)
supra_signals <- c("Strong PS", "Weak PS", "Uninfo", "Cont", "Weak ME", "Strong ME")
ms_signal_conditions <- c("-2", "obs", "+2")
ms_signal_intensities <- c("obs", "+2")

# Build the full grid, attach the intensity label implied by each
# supranational signal, and drop implausible scenarios: a row is kept only
# if both member-state dimensions are observed or both are shifted.
ScenarioData <- expand_grid(
  supra_signal = supra_signals,
  net_ms_pref_std = ms_signal_conditions,
  ms_signal_int_std = ms_signal_intensities
) %>%
  mutate(
    supra_signal_int = case_when(
      supra_signal %in% c("Cont", "Strong PS", "Strong ME") ~ "Strong Intensity",
      supra_signal %in% c("Weak PS", "Weak ME") ~ "Weak Intensity",
      supra_signal == "Uninfo" ~ "No Intensity",
      .default = NA_character_
    )
  ) %>%
  filter((net_ms_pref_std == "obs") == (ms_signal_int_std == "obs"))


# helper function to create scenarios within loop
#
# Returns a copy of `base_data` in which the scenario columns are overwritten.
#
# Arguments:
#   supra      value written to every row of `supra_signal`
#   ms_cond    "obs" to keep both member-state columns as observed; otherwise
#              a numeric label (e.g. "-2", "+2") added to `net_ms_pref_std`
#   euint      value written to every row of `supra_signal_int`
#   msint      shift added to `ms_signal_int_std` when `ms_cond` is not "obs":
#              a numeric label, or "obs" for no shift; ignored when
#              `ms_cond` is "obs"
#   base_data  data frame holding `net_ms_pref_std` and `ms_signal_int_std`
#
# Stops with an error if a shift label is neither "obs" nor numeric, instead
# of silently filling the column with NA.
generate_scenario <- function(supra, ms_cond, euint, msint, base_data) {
  # translate a scenario label into a numeric shift; "obs" means no shift
  parse_shift <- function(label, arg_name) {
    if (isTRUE(label == "obs")) {
      return(0)
    }
    shift <- suppressWarnings(as.numeric(label))
    if (length(shift) != 1L || is.na(shift)) {
      stop(
        "`", arg_name, "` must be \"obs\" or a single numeric label, got: ",
        paste(label, collapse = ", "),
        call. = FALSE
      )
    }
    shift
  }

  scenario <- base_data
  scenario$supra_signal <- supra
  scenario$supra_signal_int <- euint

  # under the observed condition both member-state columns stay untouched
  if (ms_cond != "obs") {
    scenario$net_ms_pref_std <-
      base_data$net_ms_pref_std + parse_shift(ms_cond, "ms_cond")
    scenario$ms_signal_int_std <-
      base_data$ms_signal_int_std + parse_shift(msint, "msint")
  }

  scenario
}

# list to store results, one element per scenario, named
# "<supra_signal>_<net_ms_pref_std>"
PredictedResults <- list()

# loop over scenarios and save predictions in list
# (array: ndraws x nobs)
# (Hint: takes some time to run)
# The seed is set once before the loop, so results depend on the order in
# which the scenarios are processed.
set.seed(1408)
for (i in seq_len(nrow(ScenarioData))) {
  supra <- ScenarioData$supra_signal[i]
  ms_cond <- ScenarioData$net_ms_pref_std[i]
  euint <- ScenarioData$supra_signal_int[i]
  msint <- ScenarioData$ms_signal_int_std[i]

  # The name omits ms_signal_int_std, so it is unique only while each
  # (supra_signal, net_ms_pref_std) pair occurs once in ScenarioData.
  # Fail loudly rather than silently overwrite an earlier scenario.
  scenario_name <- paste(supra, ms_cond, sep = "_")
  if (scenario_name %in% names(PredictedResults)) {
    stop("Duplicate scenario name: ", scenario_name, call. = FALSE)
  }

  scenario_data <- generate_scenario(supra, ms_cond, euint, msint, CourtData)

  # summary = FALSE keeps the individual draws instead of summary statistics
  PredictedResults[[scenario_name]] <- predict(RegModel,
    re_formula = NULL,
    newdata = scenario_data,
    summary = FALSE
  )
}

# Most frequent value of a vector (used for the modal outcome prediction).
# Ties are resolved in favour of the value that appears first in `x`.
# NOTE: this definition masks base::mode() for the rest of the script.
mode <- function(x) {
  distinct_values <- unique(x)
  frequencies <- tabulate(match(x, distinct_values))
  distinct_values[which.max(frequencies)]
}

# Get one prediction per observation by taking the mode of its draws
# (column-wise, using the script's own mode() helper), then stack all
# scenarios into a single data frame in long format.
# (Hint: takes some time to run)
TidyResults <- imap_dfr(PredictedResults, function(draws, scenario_name) {
  modal_outcome <- apply(draws, 2, mode)
  # apply() hands back a list when the per-column results differ in length
  if (is.list(modal_outcome)) {
    modal_outcome <- unlist(modal_outcome)
  }
  tibble(predicted_outcome = modal_outcome, scenario = scenario_name)
})


# fill colours, matched in order to the outcome levels PS, Ambi, ME
FillCols <- c("darkred", "orange", "aquamarine3")

# Build the figure: share of observations per predicted outcome, one panel
# per combination of supranational signal (rows) and member-state
# preference shift (columns).
PredPlot <- TidyResults %>%
  # scenario names have the form "<supra_signal>_<net_ms_pref_std>"
  separate(scenario, into = c("supra_pos", "ms_pos"), sep = "_") %>%
  # tally observations per cell, then turn the tally into a share
  count(predicted_outcome, supra_pos, ms_pos, name = "share") %>%
  mutate(
    share = share / nrow(CourtData),
    predicted_outcome = factor(
      predicted_outcome,
      levels = 1:3,
      labels = c("PS", "Ambi", "ME")
    )
  ) %>%
  # add outcome/scenario combinations that were never predicted
  complete(predicted_outcome, supra_pos, ms_pos) %>%
  mutate(
    share = replace_na(share, 0),
    supra_pos = factor(
      supra_pos,
      levels = c("Strong PS", "Weak PS", "Uninfo", "Cont", "Weak ME", "Strong ME")
    ),
    ms_pos = factor(ms_pos, levels = c("-2", "obs", "+2"))
  ) %>%
  ggplot(aes(x = predicted_outcome, y = share, fill = predicted_outcome)) +
  geom_col() +
  facet_grid(
    supra_pos ~ ms_pos,
    labeller = labeller(
      supra_pos = c(
        "Strong PS" = "Both PS",
        "Weak PS" = "Single PS",
        "Uninfo" = "Ambivalent",
        "Cont" = "Contradictory",
        "Weak ME" = "Single ME",
        "Strong ME" = "Both ME"
      ),
      ms_pos = c(
        "-2" = "-2 SD",
        "obs" = "Observed Net Pref",
        "+2" = "+2 SD"
      )
    )
  ) +
  scale_y_continuous(labels = scales::percent_format()) +
  scale_fill_manual(values = FillCols) +
  labs(x = "Predicted Outcome", y = "Share") +
  theme_bw() +
  theme(
    legend.position = "none",
    axis.text = element_text(size = 10),
    axis.title = element_text(size = 12),
    strip.text = element_text(size = 10),
    strip.text.y.right = element_text(angle = 0)
  )

# write the figure to Court/Plots/Predictions.pdf
ggsave(
  filename = "Predictions.pdf",
  plot = PredPlot,
  device = "pdf",
  path = "Court/Plots/",
  width = 9,
  height = 10
)
